🚨 EP: fix EP router contract for many models + honor FP8 scale format#46818
🚨 EP: fix EP router contract for many models + honor FP8 scale format#46818IlyasMoutawwakil wants to merge 28 commits into
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
| return Fp8Quantize(self.hf_quantizer) | ||
|
|
||
|
|
||
| class Fp8DecodeScale(ConversionOps): |
There was a problem hiding this comment.
any ideas as to why this part was dropped ?
There was a problem hiding this comment.
because i added support for ue8m0 scales in finegrained-fp8 v3, this was needed for minimax m3 with the v2, but not anymore, it also wastes memory
There was a problem hiding this comment.
ue8m0 scales are a bit messy, some store them in the correct torch dtype, some store them in uint8, and some even store them in fp32 for no special reason 😭 i'm trying to tighten the contract and honor the config all the times because supporting all the on-disk variations would be more complicated
There was a problem hiding this comment.
okay ! Just to be sure, if we remove it now, it would not break existing checkpoints that are in mxpf8 format right ?
There was a problem hiding this comment.
no they will work fine, even better because I just noticed that the fp32 scales are even avoiding the optimized mxfp8 path in https://github.com/huggingface/kernels-community/blob/aeb8ef0e09a132a6583c0a4c8b1096292922b54a/finegrained-fp8/torch-ext/finegrained_fp8/utils.py#L64 I also ran minimax m3 integration tests on the b200
There was a problem hiding this comment.
Yep sounds good, just require the version of the kernel for that path to error out properly if kernel version not installed
There was a problem hiding this comment.
we do pin the v3 in our lazy loading
| intermediate_size = config.moe_intermediate_size * config.n_shared_experts | ||
| self.shared_experts = DeepseekOcr2TextMLP(config=config, intermediate_size=intermediate_size) | ||
| self.n_routed_experts = config.n_routed_experts | ||
| self.num_experts = config.n_routed_experts |
There was a problem hiding this comment.
yeh i guess we can drop n_routed_experts, removing it
There was a problem hiding this comment.
hmm so it seems to cascade into many models
There was a problem hiding this comment.
maybe better to add _skip_if_ep_not_supported here instead of within test_ep_*?
vasqu
left a comment
There was a problem hiding this comment.
Only checked the modeling parts re modular and the models themself. It is slightly breaking technically because we move parts around modules so let's add 🚨
Generally aligned with this, just a bit unsure about the minimax m3 change - are we keeping everything as is without dequanting and then only convertin after all conversions? Not sure I can follow there 100%
vasqu
left a comment
There was a problem hiding this comment.
Just some quick comments, ping me when it's ready for another review. Seems like some comments were resolved but not addressed?
| if not self._ep_plan: | ||
| raise ValueError( | ||
| f"Expert parallelism was requested (`enable_expert_parallel=True`), but " | ||
| f"`{self.__class__.__name__}` does not define an expert-parallel plan. Add a " | ||
| f"`base_model_ep_plan` to its config, or disable expert parallelism." | ||
| ) |
There was a problem hiding this comment.
loud failure on missing ep plan
| index_head_dim: int = 128 | ||
| index_n_heads: int = 64 | ||
| mlp_bias: bool = False | ||
| num_experts: int = 256 |
There was a problem hiding this comment.
this was confusing
| def _process_model_after_weight_loading(self, model, **kwargs): | ||
| # dsv4-flash-base stores its (power-of-two) ue8m0 scales in a float32 container under | ||
| # `.scale`; those renamed keys keep the on-disk float32 dtype, so cast them to the UE8M0 | ||
| # dtype the kernels expect (exact, since the values are powers of two). Checkpoints that | ||
| # already ship the native float8 E8M0 dtype (e.g. dsv4-flash) are left untouched. | ||
| if self.quantization_config.scale_fmt == "ue8m0": | ||
| from ..integrations.finegrained_fp8 import _get_ue8m0_dtype | ||
|
|
||
| ue8m0 = _get_ue8m0_dtype() | ||
| float32_scales = [ | ||
| name | ||
| for name, param in model.named_parameters() | ||
| if name.endswith("_scale_inv") and param.dtype == torch.float32 | ||
| ] | ||
| for name in float32_scales: | ||
| module_name, _, attr = name.rpartition(".") | ||
| module = model.get_submodule(module_name) | ||
| scale = getattr(module, attr) | ||
| setattr(module, attr, torch.nn.Parameter(scale.data.to(ue8m0), requires_grad=False)) | ||
| return model |
There was a problem hiding this comment.
either like this or by hooking a quantization op to the scale rename op
There was a problem hiding this comment.
okay the second option didn't work
There was a problem hiding this comment.
I kinda prefer with a fp8DecodeScale
There was a problem hiding this comment.
why does it have to be post proc?
There was a problem hiding this comment.
Fp8DecodeScale did the opposite, and it targeted mxfp8 where it converted truly ue8m0 to fp32,
this is for for dsv4-flash-base, we need the opposite, ie convert fp32 to ue8m0 to honor the config scale_fmt (because for some reason they stored their ue8ù0 scales in fp32😭), that way we avoid casting, with a new mem allocation, at the entry of each kernel.
There was a problem hiding this comment.
why does it have to be post proc?
because the rename catches the dsv4 flash base scales first
vasqu
left a comment
There was a problem hiding this comment.
Ok so I think this looks overall good now, just a few smaller comments. Sometimes we add an attribute mapping so that all variations are kind of covered, not sure if we really need it for all models (would just double check)
The quants re minimax m3 were checked re dequant and quant so I think we are good with the changes but would like to hear @ArthurZucker's opinion on those related changes
|
Let's also update the PR description please so we summarize the changes a bit
|
ArthurZucker
left a comment
There was a problem hiding this comment.
LGTM, let's make sure kernel V is enforced
| return Fp8Quantize(self.hf_quantizer) | ||
|
|
||
|
|
||
| class Fp8DecodeScale(ConversionOps): |
There was a problem hiding this comment.
Yep sounds good, just require the version of the kernel for that path to error out properly if kernel version not installed
| if self.layer_types is None: | ||
| self.layer_types = ["deepseek_sparse_attention"] * self.num_hidden_layers | ||
|
|
||
| if (num_experts := kwargs.get("num_experts")) is not None: |
There was a problem hiding this comment.
mmm is this really something we want? let's not warn no?
There was a problem hiding this comment.
We had 2 values n_routed_experts and num_experts so it's for BC in any case a user explicitly sets this
There was a problem hiding this comment.
Removed the warning, it could indeed trigger unnecessarily
| "layers.*.mlp.experts.gate_up_proj": "grouped_gemm", | ||
| "layers.*.mlp.experts.gate_up_proj_scale_inv": "grouped_gemm", | ||
| "layers.*.mlp.experts.down_proj": "grouped_gemm", | ||
| "layers.*.mlp.experts.down_proj_scale_inv": "grouped_gemm", | ||
| "layers.*.mlp.experts": "moe_tp_experts", |
There was a problem hiding this comment.
ow shit IDK how this slipped in !
| del self.topk_method | ||
| self.norm_topk_prob = config.norm_topk_prob | ||
|
|
||
| def forward(self, hidden_states): |
There was a problem hiding this comment.
we can probably push standards but its fine
There was a problem hiding this comment.
(meaning other models do this as well exactly potentially?)
There was a problem hiding this comment.
Not sure what you mean here? It's the same as dsv2 (with a slightly different forward --> no norming at the end of the probs)
| def _process_model_after_weight_loading(self, model, **kwargs): | ||
| # dsv4-flash-base stores its (power-of-two) ue8m0 scales in a float32 container under | ||
| # `.scale`; those renamed keys keep the on-disk float32 dtype, so cast them to the UE8M0 | ||
| # dtype the kernels expect (exact, since the values are powers of two). Checkpoints that | ||
| # already ship the native float8 E8M0 dtype (e.g. dsv4-flash) are left untouched. | ||
| if self.quantization_config.scale_fmt == "ue8m0": | ||
| from ..integrations.finegrained_fp8 import _get_ue8m0_dtype | ||
|
|
||
| ue8m0 = _get_ue8m0_dtype() | ||
| float32_scales = [ | ||
| name | ||
| for name, param in model.named_parameters() | ||
| if name.endswith("_scale_inv") and param.dtype == torch.float32 | ||
| ] | ||
| for name in float32_scales: | ||
| module_name, _, attr = name.rpartition(".") | ||
| module = model.get_submodule(module_name) | ||
| scale = getattr(module, attr) | ||
| setattr(module, attr, torch.nn.Parameter(scale.data.to(ue8m0), requires_grad=False)) | ||
| return model |
There was a problem hiding this comment.
I kinda prefer with a fp8DecodeScale
| def _process_model_after_weight_loading(self, model, **kwargs): | ||
| # dsv4-flash-base stores its (power-of-two) ue8m0 scales in a float32 container under | ||
| # `.scale`; those renamed keys keep the on-disk float32 dtype, so cast them to the UE8M0 | ||
| # dtype the kernels expect (exact, since the values are powers of two). Checkpoints that | ||
| # already ship the native float8 E8M0 dtype (e.g. dsv4-flash) are left untouched. | ||
| if self.quantization_config.scale_fmt == "ue8m0": | ||
| from ..integrations.finegrained_fp8 import _get_ue8m0_dtype | ||
|
|
||
| ue8m0 = _get_ue8m0_dtype() | ||
| float32_scales = [ | ||
| name | ||
| for name, param in model.named_parameters() | ||
| if name.endswith("_scale_inv") and param.dtype == torch.float32 | ||
| ] | ||
| for name in float32_scales: | ||
| module_name, _, attr = name.rpartition(".") | ||
| module = model.get_submodule(module_name) | ||
| scale = getattr(module, attr) | ||
| setattr(module, attr, torch.nn.Parameter(scale.data.to(ue8m0), requires_grad=False)) | ||
| return model |
There was a problem hiding this comment.
why does it have to be post proc?
| parallelism = "Expert" if expert_parallel else "Tensor" | ||
| # An EP-capable MoE (@use_experts_implementation) must ship an ep_plan; assert before any | ||
| # skip so a plan-less model fails even where the parallel test can't run (GPU, old torch). | ||
| if expert_parallel and self._get_tp_model_class()._can_set_experts_implementation(): |
There was a problem hiding this comment.
perfect, we want good default EP plan evailable
There was a problem hiding this comment.
yeah, we can also make use_experts_impl take care of adding the ep_plan to the config at model init time for example
|
[For maintainers] Suggested jobs to run (before merge) run-slow: afmoe, cohere2_moe, deepseek_ocr2, deepseek_v2, deepseek_v3, deepseek_v32, dots1, ernie4_5_moe, ernie4_5_vl_moe, exaone_moe, flex_olmo, glm4_moe, glm4_moe_lite, glm4v_moe, glm_moe_dsa, hunyuan_v1_moe |
|
CI Dashboard: View test results in Grafana |
|
Need to check whether we need to update various conversion mappings; so withholding to merge for now |
What does this PR do?
Fixes # (issue)
Code Agent Policy
The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.
PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.
This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read
CONTRIBUTING.md.Before submitting
Pull Request checks?
to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.